{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "FLAX Language Model",
      "provenance": [],
      "collapsed_sections": [
        "h4-0gYhDya28"
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "bFm75GMEjt5l"
      },
      "source": [
        "# Flax Language Model Example"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lvRVndZpj614"
      },
      "source": [
        "A language model learns a probability distribution over sentences from a given corpus by modelling each subsequent token (character, word, word-piece, etc.) as an autoregressive model over past observed tokens.  This conditional distribution is commonly approximated using a \"Transformer\" decoder-stack.\n",
        "\n",
        "Here we adapt the main training script for FLAX's _lm1b_ language model example for running live in the colab environment.  "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JrA2vsXegBT5"
      },
      "source": [
        "# Preparatory Steps"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "jx-4-KxVYb0d"
      },
      "source": [
        "## Upgrade Local JAX + FLAX Packages"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "w8wUZtc9o0BQ",
        "colab": {}
      },
      "source": [
        "# Install the newest JAX and FLAX versions.\n",
        "!pip install --upgrade -q jax==0.1.61 jaxlib==0.1.42 flax==0.1.0rc2\n",
        "# Grab flax example code\n",
        "!git clone -b master https://github.com/google/flax.git flaxrepo"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "HD6EdBfqpmKt"
      },
      "source": [
        "## TPU Configuration"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "yQKmvEVBnPFb"
      },
      "source": [
        "⚠️ Make sure the Colab Runtime is set to Accelerator: TPU.<br>\n",
        "Menu: _Runtime --> Change runtime type_<br>\n",
        "Popup: _Hardware Accelerator --> TPU_<br>\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "atYp8a7uYaC_",
        "colab": {}
      },
      "source": [
        "import requests\n",
        "import os\n",
        "if 'TPU_DRIVER_MODE' not in globals():\n",
        "  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'\n",
        "  resp = requests.post(url)\n",
        "  TPU_DRIVER_MODE = 1\n",
        "# The following is required to use TPU Driver as JAX's backend.\n",
        "import os\n",
        "from jax.config import config\n",
        "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n",
        "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n",
        "print(config.FLAGS.jax_backend_target)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vZhvBULSgHOO"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "RaXDJf-zhAu2",
        "colab": {}
      },
      "source": [
        "import functools\n",
        "import itertools\n",
        "import os\n",
        "import time\n",
        "\n",
        "import flax\n",
        "from flax import jax_utils\n",
        "from flax import nn\n",
        "from flax import optim\n",
        "from flax.metrics import tensorboard\n",
        "from flax.training import checkpoints\n",
        "from flax.training import common_utils\n",
        "\n",
        "import jax\n",
        "from jax import random\n",
        "import jax.nn\n",
        "import jax.numpy as jnp\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "import tensorflow.compat.v2 as tf\n",
        "import tensorflow_datasets as tfds\n",
        "tf.enable_v2_behavior()\n",
        "\n",
        "# We directly import the FLAX Language Model example code.\n",
        "from flaxrepo.examples.lm1b import decode\n",
        "from flaxrepo.examples.lm1b import input_pipeline\n",
        "from flaxrepo.examples.lm1b import models"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "WU-VT1W3gROj"
      },
      "source": [
        "## Hyperparameters and Configuration"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Ag7k2wX-gVzf",
        "colab": {}
      },
      "source": [
        "# Make a local directory to store run data and checkpoints, etc.\n",
        "!mkdir run_1"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "bTaKQ6YNgPwj",
        "colab": {}
      },
      "source": [
        "model_dir = '/content/run_1'  # Directory to store data in.\n",
        "save_checkpoints = True       # Save local checkpoints?\n",
        "restore_checkpoints = True    # Restore from last local checkpoint?     \n",
        "checkpoint_freq = 5000        # How often to save checkpoints\n",
        "\n",
        "num_train_steps = 500000      # Max number of training steps.\n",
        "eval_frequency = 1000         # How often to run model evaluation.\n",
        "num_eval_steps = 20           # Number of steps to take during evaluation.\n",
        "random_seed = 0               # JAX PRNG random seed.\n",
        "learning_rate = 0.05          # Base learning rate.\n",
        "weight_decay = 1e-1           # AdamW-style relative weight decay factor.\n",
        "batch_size = 256              # \"Target\" Batch size.\n",
        "max_target_length = 256       # Maximum input length.\n",
        "max_eval_target_length = 256  # Maximum eval-set input length.\n",
        "\n",
        "lm_emb_dim = 512              # LM initial token embedding dimension.\n",
        "lm_num_heads = 8              # Number of heads in decoder layers.\n",
        "lm_num_layers = 6             # Number of decoder layers.\n",
        "lm_qkv_dim = 512              # Decoder query/key/value depth.\n",
        "lm_mlp_dim = 2048             # Feedforward (MLP) layer depth.\n",
        "\n",
        "prompt_str = 'The British '   # Prompt for LM Inference.\n",
        "sampling_temperature = 0.6    # Temperature to sample LM at.\n",
        "sampling_top_k = 20           # If > 0, use TopK temperature sampling.\n",
        "max_predict_token_length = 50 # Maximum number of subword tokens to predict."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RFaREjBfjJ0B"
      },
      "source": [
        "# Datasets"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "giOlB5CfCDaK"
      },
      "source": [
        "## Wikitext-2 Dataset (FAST)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lPjbnEVuQkTn"
      },
      "source": [
        "Instead of having to wait on locally building the LM1B dataset, we can instead ingest the smaller [Wikitext-2](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) dataset extracted from a small subset of the english wikipedia to train on.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "jFWQEQ-sCH2r",
        "colab": {}
      },
      "source": [
        "!wget --quiet https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip\n",
        "!unzip wikitext-2-raw-v1.zip"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "TSBG2VX0W1BQ",
        "colab": {}
      },
      "source": [
        "def preprocess_ds(path):\n",
        "  \"\"\"Extract content sentences from Wikitext-2 dataset.\"\"\"\n",
        "  dataset = tf.data.TextLineDataset(path)\n",
        "  # Drop article headers.\n",
        "  def content_filter(source):\n",
        "    return tf.logical_not(tf.strings.regex_full_match(\n",
        "        source, \n",
        "        '([[:space:]][=])+.+([[:space:]][=])+[[:space:]]*'))\n",
        "  dataset = dataset.filter(content_filter)\n",
        "\n",
        "  # Split paragraphs to lines.\n",
        "  dataset = dataset.map(lambda x: tf.strings.split(x, ' . '))\n",
        "  dataset = dataset.unbatch()\n",
        "\n",
        "  # Remove blank lines.\n",
        "  def min_filter(min_len):\n",
        "    def filter_fn(source):\n",
        "      return tf.greater(tf.strings.length(source), tf.constant(min_len))\n",
        "    return filter_fn\n",
        "  dataset = dataset.filter(min_filter(1))\n",
        "\n",
        "  return dataset\n",
        "\n",
        "# Get the raw train and eval datasets.\n",
        "train_ds = preprocess_ds('/content/wikitext-2-raw/wiki.train.raw')\n",
        "eval_ds = preprocess_ds('/content/wikitext-2-raw/wiki.valid.raw')\n",
        "\n",
        "# Build subword tokenizer.\n",
        "try:\n",
        "  # If we already ran this cell, reload the cached subword vocab file.\n",
        "  encoder = tfds.features.text.SubwordTextEncoder.load_from_file('wikitext2')\n",
        "except tf.errors.NotFoundError:\n",
        "  # Build subword tokenizer from data. Takes ~1 minute.\n",
        "  encoder = tfds.features.text.SubwordTextEncoder.build_from_corpus(\n",
        "      (x._numpy() for x in train_ds),\n",
        "      target_vocab_size=2**13,\n",
        "      max_corpus_chars=10**6)\n",
        "  encoder.save_to_file('wikitext2')\n",
        "\n",
        "# Encode strings with subword tokenizer.\n",
        "def tf_encode(x):\n",
        "  result = tf.py_function(lambda s: tf.constant(encoder.encode(s.numpy())), \n",
        "                          [x,], \n",
        "                          tf.int32)\n",
        "  result.set_shape([None])\n",
        "  return result\n",
        "train_ds=train_ds.map(tf_encode)\n",
        "eval_ds=eval_ds.map(tf_encode)\n",
        "\n",
        "# Created zero-padded length-bucketed batches.\n",
        "train_ds = input_pipeline.lm1b_preprocess(train_ds,\n",
        "                training=True,\n",
        "                n_devices=jax.local_device_count(),\n",
        "                max_target_length=256,\n",
        "                max_eval_target_length=256,\n",
        "                batch_size=256,\n",
        "                drop_remainder=True)\n",
        "\n",
        "eval_ds = input_pipeline.lm1b_preprocess(eval_ds,\n",
        "                training=False,\n",
        "                n_devices=jax.local_device_count(),\n",
        "                max_target_length=256,\n",
        "                max_eval_target_length=256,\n",
        "                batch_size=256,\n",
        "                drop_remainder=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "h4-0gYhDya28"
      },
      "source": [
        "## ⚠️ LM1B Dataset (SLOW)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "UflMmpvD_NYN"
      },
      "source": [
        "The LM1B dataset is fairly large and requires significant upfront preprocessing.  Doing it on a colab VM is possible but can be frustrating as it will take several hours to finish during which time the VM could reset.\n",
        "\n",
        "We strongly recommend downloading and preparing the dataset on a cloud instance and storing the prepared dataset on a [GCS Bucket](https://www.tensorflow.org/datasets/gcs).  Another alternative is preparing the dataset on a local machine and uploading it to a Google Drive folder which can be mounted on colab.\n",
        "\n",
        "```python\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "!cp -r /content/drive/My\\ Drive/tensorflow_datasets/lm1b ./tensorflow_datasets/lm1b\n",
        "```\n",
        "\n",
        "More IO documentation at the Colab [IO notebook](https://colab.sandbox.google.com/notebooks/io.ipynb)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "l2oecg8NyJEi",
        "colab": {}
      },
      "source": [
        "# On colab it takes an hour to download and several more hours to preprocess, \n",
        "# and you may need to babysit the colab to keep it alive. If you do this be \n",
        "# sure to copy it to a google drive folder or elsewhere as storage on a \n",
        "# Colab VM is ephemeral!\n",
        "\n",
        "# builder = tfds.builder('lm1b/subwords32k')\n",
        "# builder.download_and_prepare(download_dir='/content/tensorflow_datasets')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "W_oHIU48gA_I",
        "colab": {}
      },
      "source": [
        "# (below commented out to avoid triggering on \"run all\")\n",
        "\n",
        "# # Point to existing local data copied over:\n",
        "# data_dir = '/content/tensorflow_datasets'\n",
        "# # or a GCS Bucket:\n",
        "# # data_dir = \"gs://YOUR_BUCKET_NAME/tensorflow_datasets\"\n",
        "# train_ds, eval_ds, info_ds = input_pipeline.get_lm1b_datasets(\n",
        "#       n_devices=jax.local_device_count(),\n",
        "#       data_dir=data_dir,\n",
        "#       batch_size=batch_size,\n",
        "#       dynamic_batching=True,\n",
        "#       max_target_length=max_target_length,\n",
        "#       max_eval_target_length=max_eval_target_length)\n",
        "# vocab_size = info_ds['text'].encoder.vocab_size\n",
        "# encoder = info_ds['text'].encoder"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "SXu5ddRqCGM8"
      },
      "source": [
        "# Model "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "jiml6MRI1DZA"
      },
      "source": [
        "Defined in the `examples/lm1b/models.py` file."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "apyRChke1she",
        "colab": {}
      },
      "source": [
        "# Init PRNG Stream.\n",
        "rng = random.PRNGKey(random_seed)\n",
        "rng, init_rng = random.split(rng)\n",
        "# We init the first set of dropout PRNG keys, but update it afterwards inside\n",
        "# the main pmap'd training update for performance.\n",
        "dropout_rngs = random.split(rng, jax.local_device_count())"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Z6lzno3axMDM"
      },
      "source": [
        "### Model, Optimizer, Learning Rate"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "q3O7Ss6Erg5b",
        "colab": {}
      },
      "source": [
        "@functools.partial(jax.jit, static_argnums=(1, 2))\n",
        "def create_model(key, input_shape, model_kwargs):\n",
        "  \"\"\"\n",
        "  We create a model definition from the top-level Language Model and \n",
        "  passed in hyperparameters.\n",
        "  \"\"\"\n",
        "  module = models.TransformerLM.partial(**model_kwargs)\n",
        "  # We initialize an autoregressive Cache collection for fast, autoregressive\n",
        "  # decoding through the language model's decoder layers.\n",
        "  with nn.attention.Cache().mutate() as cache_def:\n",
        "    # create_by_shape initializes the model parameters.\n",
        "    _, model = module.create_by_shape(key,\n",
        "                                         [(input_shape, jnp.float32)],\n",
        "                                         cache=cache_def)\n",
        "  return model, cache_def\n",
        "\n",
        "# Init model and optimizer.\n",
        "vocab_size = encoder.vocab_size\n",
        "input_shape = (batch_size, max_target_length)\n",
        "transformer_lm_kwargs = {\n",
        "    'vocab_size': vocab_size,\n",
        "    'emb_dim': lm_emb_dim,\n",
        "    'num_heads': lm_num_heads,\n",
        "    'num_layers': lm_num_layers,\n",
        "    'qkv_dim': lm_qkv_dim,\n",
        "    'mlp_dim': lm_mlp_dim,\n",
        "    'max_len': max(max_target_length, max_eval_target_length)\n",
        "}\n",
        "model, cache_def = create_model(init_rng, input_shape, transformer_lm_kwargs)\n",
        "\n",
        "\n",
        "def create_optimizer(model, learning_rate):\n",
        "  \"\"\"\n",
        "  Here we define the AdamW optimizer we'll use.\n",
        "  \"\"\"\n",
        "  optimizer_def = optim.Adam(\n",
        "      learning_rate,\n",
        "      beta1=0.9,\n",
        "      beta2=0.98,\n",
        "      eps=1e-9,\n",
        "      weight_decay=weight_decay)\n",
        "  optimizer = optimizer_def.create(model)\n",
        "  optimizer = optimizer.replicate()\n",
        "  return optimizer\n",
        "\n",
        "# Build an optimizer from the model.\n",
        "optimizer = create_optimizer(model, learning_rate)\n",
        "# Don't keep a copy of the initial model object.\n",
        "# if needed, we instead access the model directly via optimizer.target\n",
        "del model\n",
        "\n",
        "\n",
        "def create_learning_rate_scheduler(base_learning_rate=0.5, warmup_steps=8000):\n",
        "  \"\"\"Define our learning rate schedule.\"\"\"\n",
        "  def step_fn(step):\n",
        "    return jnp.asarray(\n",
        "        base_learning_rate * \n",
        "        jnp.minimum(1.0, step / warmup_steps) /\n",
        "        jnp.sqrt(jnp.maximum(step, warmup_steps)), dtype=jnp.float32)\n",
        "  return step_fn\n",
        "\n",
        "learning_rate_fn = create_learning_rate_scheduler(\n",
        "    base_learning_rate=learning_rate)"
      ],
      "execution_count": 0,
      "outputs": []
    },
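    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a worked form of the schedule defined in `create_learning_rate_scheduler` above: the learning rate warms up linearly for `warmup_steps` steps and then decays with the inverse square root of the step count. The symbols $\\eta_0$ (for `base_learning_rate`), $w$ (for `warmup_steps`), and $s$ (for `step`) are notation introduced here for illustration, not names from the code:\n",
        "\n",
        "$$\\eta(s) = \\eta_0 \\cdot \\frac{\\min(1, s / w)}{\\sqrt{\\max(s, w)}}$$\n",
        "\n",
        "With the values used in this notebook ($\\eta_0 = 0.05$, $w = 8000$), the rate peaks at $s = w$ with $\\eta(w) = 0.05 / \\sqrt{8000} \\approx 5.6 \\times 10^{-4}$."
      ]
    },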
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wR39MzcOxQMV"
      },
      "source": [
        "### Loss Function and Auxiliary Metrics"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "WXj9NF0Dq_rW",
        "colab": {}
      },
      "source": [
        "def compute_weighted_cross_entropy(logits, targets, weights=None):\n",
        "  \"\"\"Compute weighted cross entropy and entropy for log probs and targets.\n",
        "\n",
        "  Args:\n",
        "   logits: [batch, length, num_classes] float array.\n",
        "   targets: categorical targets [batch, length] int array.\n",
        "   weights: None or array of shape [batch x length]\n",
        "\n",
        "  Returns:\n",
        "    Tuple of scalar loss and batch normalizing factor.\n",
        "  \"\"\"\n",
        "  if logits.ndim != targets.ndim + 1:\n",
        "    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %\n",
        "                     (str(logits.shape), str(targets.shape)))\n",
        "  onehot_targets = common_utils.onehot(targets, logits.shape[-1])\n",
        "  loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)\n",
        "  normalizing_factor = onehot_targets.sum()\n",
        "  if weights is not None:\n",
        "    loss = loss * weights\n",
        "    normalizing_factor = weights.sum()\n",
        "\n",
        "  return loss.sum(), normalizing_factor\n",
        "\n",
        "\n",
        "def compute_weighted_accuracy(logits, targets, weights=None):\n",
        "  \"\"\"Compute weighted accuracy for log probs and targets.\n",
        "\n",
        "  Args:\n",
        "   logits: [batch, length, num_classes] float array.\n",
        "   targets: categorical targets [batch, length] int array.\n",
        "   weights: None or array of shape [batch x length]\n",
        "\n",
        "  Returns:\n",
        "    Tuple of scalar accuracy and batch normalizing factor.\n",
        "  \"\"\"\n",
        "  if logits.ndim != targets.ndim + 1:\n",
        "    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %\n",
        "                     (str(logits.shape), str(targets.shape)))\n",
        "  loss = jnp.equal(jnp.argmax(logits, axis=-1), targets)\n",
        "  normalizing_factor = jnp.prod(logits.shape[:-1])\n",
        "  if weights is not None:\n",
        "    loss = loss * weights\n",
        "    normalizing_factor = weights.sum()\n",
        "\n",
        "  return loss.sum(), normalizing_factor\n",
        "\n",
        "\n",
        "def compute_metrics(logits, labels, weights):\n",
        "  \"\"\"Compute summary metrics.\"\"\"\n",
        "  loss, weight_sum = compute_weighted_cross_entropy(logits, labels, weights)\n",
        "  acc, _ = compute_weighted_accuracy(logits, labels, weights)\n",
        "  metrics = {\n",
        "      'loss': loss,\n",
        "      'accuracy': acc,\n",
        "      'denominator': weight_sum,\n",
        "  }\n",
        "  metrics = jax.lax.psum(metrics, axis_name='batch')\n",
        "  return metrics"
      ],
      "execution_count": 0,
      "outputs": []
    },
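    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "In equation form, `compute_weighted_cross_entropy` above returns the numerator and denominator of the mean token loss separately, so that both can be summed across devices and steps before dividing. Writing $z_{b,t}$ for the logits, $y_{b,t}$ for the target token, and $w_{b,t}$ for the weights at batch index $b$ and position $t$ (notation introduced here for illustration, not names from the code):\n",
        "\n",
        "$$\\mathcal{L} = \\frac{-\\sum_{b,t} w_{b,t} \\, \\log \\mathrm{softmax}(z_{b,t})_{y_{b,t}}}{\\sum_{b,t} w_{b,t}}$$\n",
        "\n",
        "The perplexity reported in the training loop is $\\exp(\\mathcal{L})$, clipped to at most $10^4$."
      ]
    },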
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "H0uzc41CzyBG"
      },
      "source": [
        "### Main training, evaluation, and inference functions."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "-efZctbyrCh-",
        "colab": {}
      },
      "source": [
        "def train_step(optimizer, inputs, learning_rate_fn, dropout_rng=None):\n",
        "  \"\"\"Perform a single training step.\"\"\"\n",
        "  weights = jnp.where(inputs > 0, 1, 0)\n",
        "\n",
        "  # We handle PRNG splitting inside the top pmap, rather\n",
        "  # than handling it outside in the training loop - doing the\n",
        "  # latter can add some stalls to the devices.\n",
        "  dropout_rng, new_dropout_rng = random.split(dropout_rng)\n",
        "\n",
        "  def loss_fn(model):\n",
        "    \"\"\"Loss function used for training.\"\"\"\n",
        "    with nn.stochastic(dropout_rng):\n",
        "      logits = model(inputs, train=True)\n",
        "    loss, weight_sum = compute_weighted_cross_entropy(logits, inputs, weights)\n",
        "    mean_loss = loss / weight_sum\n",
        "    return mean_loss, logits\n",
        "\n",
        "  step = optimizer.state.step\n",
        "  lr = learning_rate_fn(step)\n",
        "  new_optimizer, _, logits = optimizer.optimize(loss_fn, learning_rate=lr)\n",
        "  metrics = compute_metrics(logits, inputs, weights)\n",
        "  metrics['learning_rate'] = lr\n",
        "\n",
        "  return new_optimizer, metrics, new_dropout_rng\n",
        "\n",
        "# parallelize the training step with JAX's pmap.\n",
        "p_train_step = jax.pmap(\n",
        "    functools.partial(train_step, learning_rate_fn=learning_rate_fn),\n",
        "    axis_name='batch')\n",
        "\n",
        "\n",
        "def eval_step(model, inputs):\n",
        "  weights = jnp.where(inputs > 0, 1, 0)\n",
        "  logits = model(inputs, train=False)\n",
        "  return compute_metrics(logits, inputs, weights)\n",
        "\n",
        "# parallelize the evaluation step with JAX's pmap.\n",
        "p_eval_step = jax.pmap(eval_step, axis_name='batch')\n",
        "\n",
        "\n",
        "def predict_step(inputs, model, cache, prng_key):\n",
        "  \"\"\"Fast sampling of language model from prompt.\"\"\"\n",
        "  prefix_len = inputs.shape[1]\n",
        "  pad_len = max_predict_token_length - prefix_len\n",
        "  padded_inputs = jnp.pad(inputs, jnp.array([[0, 0], [0, pad_len]]))\n",
        "\n",
        "  def tokens_ids_to_logits(ids, cache):\n",
        "    \"\"\"Token slice to logits from decoder model.\"\"\"\n",
        "    with cache.mutate() as new_cache:\n",
        "      logits = model(ids, shift=False, train=False, cache=new_cache)\n",
        "    # Remove singleton sequence-length dimension from model.\n",
        "    # [batch, 1, vocab] --> [batch, vocab]\n",
        "    logits = logits.squeeze(axis=1)\n",
        "    return logits, new_cache\n",
        "\n",
        "  sampled_seqs = decode.temperature_sample(\n",
        "      padded_inputs,\n",
        "      cache,\n",
        "      tokens_ids_to_logits,\n",
        "      prng_key,\n",
        "      temperature=sampling_temperature,\n",
        "      topk=sampling_top_k,\n",
        "      eos_token=2**16)  # No EOS tokens used in default lm1b dataset encoding.\n",
        "\n",
        "  return sampled_seqs\n",
        "\n",
        "# parallelize the fast autoregressive sampler with JAX's pmap.\n",
        "p_pred_step = jax.pmap(predict_step, axis_name='batch')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "K82k6M5IfKuT"
      },
      "source": [
        "# Tensorboard Logging"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "WMjrDy-Dd_WK",
        "colab": {}
      },
      "source": [
        "# Load the TensorBoard notebook extension.\n",
        "%load_ext tensorboard"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "9jzyNEMBeOAJ",
        "colab": {}
      },
      "source": [
        "# Launch an inline tensorboard panel.\n",
        "%tensorboard --logdir /content/run_1"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qUZdSfR_fGcw"
      },
      "source": [
        "# Main Training Loop"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "v_HuG_IhePxw",
        "colab": {}
      },
      "source": [
        "# Summary writers for tensorboard.\n",
        "train_summary_writer = tensorboard.SummaryWriter(\n",
        "    os.path.join(model_dir, 'train'))\n",
        "eval_summary_writer = tensorboard.SummaryWriter(\n",
        "    os.path.join(model_dir, 'eval'))\n",
        "\n",
        "# Initialize training dataset iterator.\n",
        "train_iter = iter(train_ds)\n",
        "start_step = 0\n",
        "\n",
        "if restore_checkpoints:\n",
        "  # Restore unreplicated optimizer + model state from last checkpoint.\n",
        "  optimizer = checkpoints.restore_checkpoint(model_dir, optimizer)\n",
        "  # Grab last step from the first of the optimizer replicas.\n",
        "  start_step = int(optimizer.state.step[0])\n",
        "\n",
        "metrics_all = []    # We aggregate and average training metrics here.\n",
        "tick = time.time()  # Initialize step timer.\n",
        "\n",
        "print('Compiling XLA programs for different input shapes,'\n",
        "      ' this can take 5-10 minutes.')\n",
        "for step, batch in zip(range(start_step, num_train_steps), train_iter):\n",
        "\n",
        "  # Core training step.\n",
        "  batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))\n",
        "  optimizer, metrics, dropout_rngs = p_train_step(\n",
        "      optimizer, batch, dropout_rng=dropout_rngs)\n",
        "  metrics_all.append(metrics)\n",
        "\n",
        "  # Save a Checkpoint\n",
        "  if step % checkpoint_freq == 0 and step > 0:\n",
        "    if save_checkpoints:\n",
        "      checkpoints.save_checkpoint(model_dir, optimizer, step)\n",
        "\n",
        "  # Periodic metric handling.\n",
        "  if step % eval_frequency == 0 and step > 0:\n",
        "    metrics_all = common_utils.get_metrics(metrics_all)\n",
        "    lr = metrics_all.pop('learning_rate').mean()\n",
        "    metrics_sums = jax.tree_map(jnp.sum, metrics_all)\n",
        "    denominator = metrics_sums.pop('denominator')\n",
        "    summary = jax.tree_map(lambda x: x / denominator, metrics_sums)\n",
        "    summary['learning_rate'] = lr\n",
        "    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)\n",
        "    \n",
        "    # Update step timer.\n",
        "    tock = time.time()\n",
        "    steps_per_sec = eval_frequency / (tock - tick)\n",
        "    tick = tock\n",
        "    train_summary_writer.scalar('steps per second', steps_per_sec, step)\n",
        "\n",
        "    print('train in step: %d, loss: %.4f' %(step, summary['loss']))\n",
        "    for key, val in summary.items():\n",
        "      train_summary_writer.scalar(key, val, step)\n",
        "    train_summary_writer.flush()\n",
        "    # reset metric accumulation for next evaluation cycle.\n",
        "    metrics_all = []\n",
        "\n",
        "\n",
        "    # Eval Metrics -----------------------------------------------------------\n",
        "    eval_metrics = []\n",
        "    for _, eval_batch in zip(range(num_eval_steps), iter(eval_ds)):\n",
        "      eval_batch = common_utils.shard(\n",
        "          jax.tree_map(lambda x: x._numpy(), eval_batch))\n",
        "      metrics = p_eval_step(optimizer.target, eval_batch)\n",
        "      eval_metrics.append(metrics)\n",
        "\n",
        "    eval_metrics = common_utils.get_metrics(eval_metrics)\n",
        "    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)\n",
        "    eval_denominator = eval_metrics_sums.pop('denominator')\n",
        "    eval_summary = jax.tree_map(\n",
        "        lambda x: x / eval_denominator,\n",
        "        eval_metrics_sums)\n",
        "    eval_summary['perplexity'] = jnp.clip(\n",
        "        jnp.exp(eval_summary['loss']), a_max=1.0e4)\n",
        "\n",
        "    print('eval in step: %d, loss: %.4f'%(step, eval_summary['loss']))\n",
        "    for key, val in eval_summary.items():\n",
        "      eval_summary_writer.scalar(key, val, step)\n",
        "    eval_summary_writer.flush()\n",
        "\n",
        "\n",
        "    # Fast inference of prompt extension using trained LM. -------------------\n",
        "    # Update rng stream for prediction.\n",
        "    rng, subrng = jax.random.split(rng)\n",
        "    pred_rngs = random.split(subrng, jax.local_device_count())\n",
        "\n",
        "    # Encode provided text prompt to initialize sampling.\n",
        "    prompt = jnp.array(encoder.encode(prompt_str))\n",
        "    prompt = jax_utils.replicate(prompt)\n",
        "    prompt = jnp.reshape(prompt, (prompt.shape[0], 1, prompt.shape[1]))\n",
        "\n",
        "    # Initialize the autoregressive cache, run prediction loop, collect data.\n",
        "    cache = jax_utils.replicate(\n",
        "        cache_def.initialize_cache((1, max_predict_token_length)))\n",
        "    predicted = p_pred_step(prompt, optimizer.target, cache, pred_rngs)\n",
        "    predicted = np.array(predicted).reshape(\n",
        "      (predicted.shape[0] * predicted.shape[1],) + predicted.shape[2:])\n",
        "\n",
        "    # Write examples for tensorboard.\n",
        "    print(encoder.decode(predicted[0]))\n",
        "    exemplars = ''\n",
        "    for n in range(predicted.shape[0]):\n",
        "      exemplars += encoder.decode(predicted[n]) + '\\n\\n'\n",
        "    eval_summary_writer.text('samples', exemplars, step)\n",
        "    eval_summary_writer.flush()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "b3jul9RswrzD"
      },
      "source": [
        "# Fast inference on the language model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "PnMTrVvPxd3F",
        "colab": {}
      },
      "source": [
        "# Optional - we can skip training and restore from a saved checkpoint.\n",
        "# optimizer = checkpoints.restore_checkpoint(model_dir, optimizer)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "x4AEVpEye1Ew",
        "colab": {}
      },
      "source": [
        "def predict(rng, prompt_str):\n",
        "  # Update rng stream for prediction.\n",
        "  rng, subrng = jax.random.split(rng)\n",
        "  pred_rngs = random.split(subrng, jax.local_device_count())\n",
        "  # Encode provided text prompt to initialize sampling.\n",
        "  prompt = jnp.array(encoder.encode(prompt_str))\n",
        "  prompt = jax_utils.replicate(prompt)\n",
        "  prompt = jnp.reshape(prompt, (prompt.shape[0], 1, prompt.shape[1]))\n",
        "  # Initialize the autoregressive cache, run prediction loop, collect data.\n",
        "  cache = jax_utils.replicate(\n",
        "      cache_def.initialize_cache((1, max_predict_token_length)))\n",
        "  predicted = p_pred_step(prompt, optimizer.target, cache, pred_rngs)\n",
        "  predicted = np.array(predicted).reshape(\n",
        "      (predicted.shape[0] * predicted.shape[1],) + predicted.shape[2:])\n",
        "  # Print generated sentences.\n",
        "  exemplars = ''\n",
        "  for n in range(predicted.shape[0]):\n",
        "    exemplars += encoder.decode(predicted[n]) + '\\n\\n'\n",
        "  print(exemplars)\n",
        "  # Return rng stream.\n",
        "  return rng"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "NDbBo1CQwFh1",
        "colab": {}
      },
      "source": [
        "rng = predict(rng, \"The kakapo is \")"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}
